#!/usr/bin/env python
# -*- coding: utf8 -*-

import os
import sys

# REverse compiler for ST music modules pre 07 version.
# (c) 2011 by Mono

# Debug-logging switch; it is not read anywhere else in this file.
log = False

# Exception raised when a file is not a binary file
class NotBinaryFileException (Exception):
	"""Signal that *file* does not start with the $FFFF binary-file marker."""

	def __init__(self, file):
		super(NotBinaryFileException, self).__init__(file)
		# The offending stream, kept for the handler.
		self.file = file

# Exception raised when the file holds no more binary blocks
class NoMoreDataException (Exception):
	"""Signal that *file* was already exhausted when a new block was requested."""

	def __init__(self, file):
		super(NoMoreDataException, self).__init__(file)
		# The exhausted stream, kept for the handler.
		self.file = file

# One block of a binary file
class BinaryBlock:
	"""A single block: start address, inclusive end address and the bytes between.

	The addresses may be preceded by an optional $FFFF marker, which is skipped.
	"""

	def __init__(self, file, requireHeader = False):
		"""Read one block from *file*.

		:param file: stream opened in binary mode, positioned at a block.
		:param requireHeader: if true, the block must start with $FFFF.
		:raises NoMoreDataException: when the stream is already at its end.
		:raises NotBinaryFileException: when requireHeader is set and the
			$FFFF marker is missing.
		:raises ValueError: when an address is truncated or the end address
			lies more than one byte before the start address.
		"""
		word = self._readWord(file)
		if word is None:
			raise NoMoreDataException(file)
		# Only the exact value $FFFF is the marker; a start address with a
		# single $FF byte (e.g. $20FF) is a real address.
		if word == 0xffff:
			word = self._readWord(file)
		elif requireHeader:
			raise NotBinaryFileException(file)
		if word is None:
			raise ValueError("truncated binary block: missing start address")
		self.first = word

		word = self._readWord(file)
		if word is None:
			raise ValueError("truncated binary block: missing end address")
		length = word + 1 - self.first
		if length < 0:
			# file.read() with a negative size would silently read to EOF.
			raise ValueError("invalid binary block: end $%04x before start $%04x" % (word, self.first))
		self.last = word

		# The end address is inclusive; a short read (truncated file) is kept as-is.
		self.content = bytearray(file.read(length))

	@staticmethod
	def _readWord(file):
		"""Read a little-endian 16-bit word from *file*; return None at end of file.

		:raises ValueError: if exactly one byte is left.
		"""
		buf = bytearray(file.read(2))
		if len(buf) == 0:
			return None
		if len(buf) < 2:
			raise ValueError("truncated binary block: odd trailing byte")
		return buf[0] + 256 * buf[1]

	def __str__(self):
		return "$%04x..$%04x" % (self.first, self.last)

# Binary file made of consecutive blocks
class BinaryFile:
	"""All BinaryBlock objects read, in order, from one file."""

	def __init__(self, filename):
		"""Read every block of *filename*.

		Reading stops cleanly when the file is exhausted; any other parsing
		error propagates after the file has been closed.
		"""
		self.blocks = []
		with open(filename, "rb") as stream:
			try:
				# The first block does not insist on the $FFFF marker
				# (BinaryBlock(stream, True) would).
				while True:
					self.blocks.append(BinaryBlock(stream))
			except NoMoreDataException:
				pass

	def __str__(self):
		# __str__ must return a string, not the list itself.
		return ", ".join(str(block) for block in self.blocks)

class CompiledSong:
	"""Song section of a compiled module: loop position and raw song bytes."""

	def __init__(self, loop, data):
		self.loop, self.data = loop, data

	def __str__(self):
		fields = (self.loop, self.data)
		return "$%02x, %s" % fields

class CompiledInstrument:
	"""Instrument record of a compiled module.

	Holds the volume and frequency envelope lengths, the distortion byte, a
	reserved byte and the raw envelope bytes.
	"""

	def __init__(self, volumeEnvelopeLength, frequencyEnvelopeLength, distortion, reserved, envelope):
		self.volumeEnvelopeLength = volumeEnvelopeLength
		self.frequencyEnvelopeLength = frequencyEnvelopeLength
		self.distortion = distortion
		self.reserved = reserved
		self.envelope = envelope

	def __str__(self):
		# Five values need five placeholders; the envelope is printed with %s.
		return "$%02x, $%02x, $%02x, $%02x, %s" % (self.volumeEnvelopeLength, self.frequencyEnvelopeLength, self.distortion, self.reserved, self.envelope)

class CompiledPattern:
	"""Pattern record of a compiled module: header values plus its three data blocks."""

	def __init__(self, length, tempo, audctl, notes, frequencies, instruments):
		# Header bytes.
		self.length, self.tempo, self.audctl = length, tempo, audctl
		# Raw note, frequency and instrument blocks.
		self.notes, self.frequencies, self.instruments = notes, frequencies, instruments

	def __str__(self):
		header = (self.length, self.tempo, self.audctl)
		blocks = (self.notes, self.frequencies, self.instruments)
		return "$%02x, $%02x, $%02x, %s, %s, %s" % (header + blocks)

# Compiled ST module split into song, instruments and patterns
class CompiledModule:
	"""Parse a compiled ST module image into song, instrument and pattern records.

	Layout as parsed here: ``songLength`` song bytes, then a table of 16-bit
	little-endian addresses (instruments first, then patterns), then the data
	those addresses point to. Section lengths are not stored; each section is
	taken to run up to the next known start offset, and pattern sub-blocks are
	trimmed with consistency heuristics (a message is printed when that happens).
	"""

	# Maximum number of address-table entries scanned (0x40 + 0x33);
	# presumably the pattern limit plus the instrument limit - TODO confirm.
	MAX_TABLE_ENTRIES = 0x40+0x33

	def __init__(self, address, data, songLoop, songLength):
		"""Parse *data*.

		:param address: load address of the first byte of *data*.
		:param data: module image (bytearray).
		:param songLoop: song loop position, stored as-is.
		:param songLength: number of song bytes at the start of *data*.
		:raises ValueError: if the song is empty, or the address table holds
			fewer valid entries than the song needs patterns.
		"""
		self.instruments = []
		self.patterns = []

		# Little-endian 16-bit word stored at data[offset].
		readWord = lambda offset: data[offset]+256*data[offset+1]

		songData = data[0:songLength]
		if len(songData) == 0:
			raise ValueError("empty song data")
		self.song = CompiledSong( songLoop, songData )

		# The address table follows the song. Scan it until an entry no longer
		# points inside the image (or the image ends); each valid entry is the
		# start of an instrument or a pattern.
		addressTableOffset = songLength
		offsets = []
		for a in range(0, self.MAX_TABLE_ENTRIES):
			entryOffset = addressTableOffset+a*2
			if entryOffset+1 >= len(data):
				break
			itemAddress = readWord(entryOffset)
			if itemAddress < address or itemAddress > address+len(data):
				break
			offsets.append(itemAddress-address)
		# Song bytes are treated as pattern numbers; the highest gives the count.
		patternsCount = max(songData)+1
		# Count valid entries directly: the loop index is one short when the
		# scan runs through every entry without breaking.
		instrumentsCount = len(offsets)-patternsCount
		if instrumentsCount < 0:
			raise ValueError("address table has %i valid entries but the song needs %i patterns" % (len(offsets), patternsCount))

		instrumentsAddressTableOffset = addressTableOffset
		patternsAddressTableOffset = instrumentsAddressTableOffset+instrumentsCount*2
		# Table positions and the end of data are region boundaries too.
		offsets.append(instrumentsAddressTableOffset)
		offsets.append(patternsAddressTableOffset)
		offsets.append(len(data))
		offsets.sort()

		# End of the region starting at addr: the following entry of the sorted
		# boundary list (an equal duplicate yields an empty region).
		findNext = lambda addr: offsets[offsets.index(addr)+1]

		print("Instruments: %i" % instrumentsCount)
		for n in range(0, instrumentsCount):
			startOffset = readWord(instrumentsAddressTableOffset+n*2) - address
			stopOffset = findNext(startOffset)
			# 4 header bytes, then the envelope up to the next region.
			frequencyEnvelopeLength = data[startOffset+0]
			volumeEnvelopeLength = data[startOffset+1]
			distortion = data[startOffset+2]
			reserved = data[startOffset+3]
			envelope = data[startOffset+4:stopOffset]
			self.instruments.append( CompiledInstrument(volumeEnvelopeLength, frequencyEnvelopeLength, distortion, reserved, envelope) )

		print("Patterns: %i" % patternsCount)
		for n in range(0, patternsCount):
			startOffset = readWord(patternsAddressTableOffset+n*2) - address

			# 5 header bytes: length, tempo, AUDCTL and the offsets (relative
			# to the pattern start) of the frequency and instrument blocks;
			# the note block follows the header directly.
			patternLength = data[startOffset+0]
			patternTempo = data[startOffset+1]
			patternAudctl = data[startOffset+2]
			patternFreqOffset = startOffset+data[startOffset+3]
			patternInstrOffset = startOffset+data[startOffset+4]
			patternNoteOffset = startOffset+5
			offsets.extend((patternFreqOffset, patternInstrOffset, patternNoteOffset))
			offsets.sort()

			# Note block: each byte covers two rows; every set bit consumes
			# one byte of the frequency block. Stop once the pattern length
			# is reached.
			startOff = patternNoteOffset
			stopOff = findNext(patternNoteOffset)
			noteCntr = 0
			freqCntr = 0
			for o in range(startOff, stopOff):
				freqCntr = freqCntr+bin(data[o]).count("1")
				noteCntr = noteCntr+2
				if noteCntr >= patternLength and o+1 != stopOff:
					stopOff = o+1
					print("Bad pattern %i note block length - shrinking data at $%04x" % (n+1, stopOff))
					break
			patternNotes = data[startOff:stopOff]

			# Frequency block: no longer than the set bits counted above.
			startOff = patternFreqOffset
			stopOff = findNext(patternFreqOffset)
			if startOff+freqCntr < stopOff:
				stopOff = startOff+freqCntr
				print("Bad pattern %i frequency block length - shrinking data at $%04x" % (n+1, stopOff))
			patternFreqs = data[startOff:stopOff]

			# Instrument block: 5-byte records whose first byte is a row that
			# must increase and stay below the pattern length. A $46 byte that
			# is the last byte of the block is accepted as-is; anything else
			# ends the block.
			startOff = patternInstrOffset
			stopOff = findNext(patternInstrOffset)
			patternPos = -1
			for o in range(startOff, stopOff, 5):
				if data[o] < patternLength and data[o] > patternPos:
					patternPos = data[o]
				elif data[o] != 0x46 or o+1 != stopOff:
					stopOff = o
					print("Bad pattern %i instruments block length - shrinking data at $%04x" % (n+1, stopOff))
					break
			patternInstrs = data[startOff:stopOff]

			self.patterns.append( CompiledPattern(patternLength, patternTempo, patternAudctl, patternNotes, patternFreqs, patternInstrs) )

	def __str__(self):
		return "%s, %s, %s" % (self.song, self.instruments, self.patterns)

class SourceSong:
	"""Song part of a source module: loop position and song bytes."""

	def __init__(self, loop, data):
		self.loop, self.data = loop, data

class SourcePattern:
	"""Pattern in source form: length, tempo, AUDCTL and its note/instrument tables."""

	def __init__(self, length, tempo, audctl, notes, instruments):
		header = (length, tempo, audctl)
		self.length, self.tempo, self.audctl = header
		self.notes, self.instruments = notes, instruments

class SourceInstrument:
	"""Instrument in source form: an envelope plus distortion and reserved bytes."""

	def __init__(self, envelope, distortion, reserved):
		self.envelope, self.distortion, self.reserved = envelope, distortion, reserved

class SourceEnvelope:
	"""Instrument envelope in source form: volume/frequency lengths and envelope bytes."""

	def __init__(self, volumeLength, frequencyLength, data):
		lengths = (volumeLength, frequencyLength)
		self.volumeLength, self.frequencyLength = lengths
		self.data = data

# ST module in source form
class SourceModule:
	"""Module in source form: instruments, patterns and song, writable as .ST."""

	def __init__(self, instruments, patterns, song):
		self.instruments = instruments
		self.patterns = patterns
		self.song = song

	def __str__(self):
		return "%s, %s, %s" % (self.instruments, self.patterns, self.song)

	def write(self, filename):
		"""Write the module to *filename*.

		Layout written: the "MUSIC " signature; the instrument count, then for
		each instrument its index, frequency and volume envelope lengths,
		distortion, reserved byte and envelope data; the pattern count, then
		for each pattern its index, length, tempo, AUDCTL, note table and
		instrument table; finally the song loop, song length and song data.

		:raises ValueError: if any single-byte value is outside 0..255. The
			file is only created once the whole image has been built, so no
			partial file is left behind in that case.
		"""
		out = bytearray(b"MUSIC ")

		print("Instruments: %i" % len(self.instruments))
		out.append(len(self.instruments))
		for n, instrument in enumerate(self.instruments):
			envelope = instrument.envelope
			out.extend((n, envelope.frequencyLength, envelope.volumeLength, instrument.distortion, instrument.reserved))
			out.extend(envelope.data)

		print("Patterns: %i" % len(self.patterns))
		out.append(len(self.patterns))
		for n, pattern in enumerate(self.patterns):
			out.extend((n, pattern.length, pattern.tempo, pattern.audctl))
			out.extend(pattern.notes)
			out.extend(pattern.instruments)

		print("Song: %i" % len(self.song.data))
		out.extend((self.song.loop, len(self.song.data)))
		out.extend(self.song.data)

		# A single write of a bytearray works on both Python 2 and 3.
		with open(filename, "wb") as stream:
			stream.write(out)

# Converter from the compiled form to the source form
class Converter:
	"""Rebuild source-form module data from a parsed CompiledModule."""

	# Instrument envelopes are zero-padded to this many bytes.
	ENVELOPE_SIZE = 0x3c
	# Pattern note and instrument tables are zero-padded to 0x40 rows of 4 entries.
	PATTERN_TABLE_SIZE = 0x40*4

	def convert(self, module):
		"""Convert *module* (a CompiledModule) into a SourceModule.

		The compiled module is left unmodified.
		"""
		instruments = [self._convertInstrument(instrument) for instrument in module.instruments]
		patterns = [self._convertPattern(pattern) for pattern in module.patterns]
		song = SourceSong(module.song.loop, module.song.data)
		return SourceModule(instruments, patterns, song)

	def _convertInstrument(self, instrument):
		"""Build a SourceInstrument whose envelope is zero-padded to ENVELOPE_SIZE.

		Longer envelopes are kept whole (NOTE(review): confirm the .ST format
		never needs them truncated).
		"""
		# Pad a copy: padding in place would change the compiled module.
		envelope = bytearray(instrument.envelope)
		envelope.extend(bytearray(max(0, self.ENVELOPE_SIZE-len(envelope))))
		return SourceInstrument(
			SourceEnvelope(instrument.volumeEnvelopeLength, instrument.frequencyEnvelopeLength, envelope),
			instrument.distortion, instrument.reserved)

	def _convertPattern(self, pattern):
		"""Expand packed note bits and sparse instrument records into per-row tables.

		Each nibble of the note block describes one row (high nibble first).
		Each of its 4 bits, most significant first, either takes the next
		frequency byte plus 8, or is empty (0). An instrument record (row
		followed by 4 instruments) switches the instruments from that row on.
		Every row receives the current instruments plus 1, modulo 256.
		"""
		notes = pattern.notes
		freqs = pattern.frequencies
		instrs = pattern.instruments

		sourceNotes = bytearray()
		sourceInstrs = bytearray()
		instrIndex = 0
		freqIndex = 0
		# 0xff + 1 wraps to 0, so rows before the first record get zeros.
		currentInstrs = [0xff, 0xff, 0xff, 0xff]
		for n in range(len(notes)):
			for half in range(2):
				nibble = (notes[n] >> 4) if half == 0 else (notes[n] & 0x0f)
				for bit in (0x08, 0x04, 0x02, 0x01):
					if nibble & bit:
						sourceNotes.append(freqs[freqIndex]+8)
						freqIndex += 1
					else:
						sourceNotes.append(0)
				row = n*2+half
				if instrIndex < len(instrs) and row == instrs[instrIndex]:
					currentInstrs = instrs[instrIndex+1:instrIndex+5]
					instrIndex += 5
				for i in currentInstrs:
					sourceInstrs.append((i+1) & 0xff)
		sourceNotes.extend(bytearray(max(0, self.PATTERN_TABLE_SIZE-len(sourceNotes))))
		sourceInstrs.extend(bytearray(max(0, self.PATTERN_TABLE_SIZE-len(sourceInstrs))))
		return SourcePattern(pattern.length, pattern.tempo, pattern.audctl, sourceNotes, sourceInstrs)

# Main entry point
if __name__ == "__main__":
	progname = sys.argv[0]
	if len(sys.argv) >= 4:
		modname = sys.argv[1]
		songloop = int(sys.argv[2])
		songlen = int(sys.argv[3])
		if len(sys.argv) > 4:
			sourcename = sys.argv[4]
		else:
			# Replace the file name's extension with ".ST". splitext only looks
			# at the last path component, so a dot in a directory name (or the
			# leading dot of a hidden file) is not taken as an extension.
			sourcename = os.path.splitext(modname)[0] + ".ST"
		print("Reading %s compiled file" % modname)
		binfile = BinaryFile(modname)
		if len(binfile.blocks) > 0:
			if len(binfile.blocks) > 1:
				print("File has %i blocks - get only first" % len(binfile.blocks))
			block = binfile.blocks[0]
			module = CompiledModule(block.first, block.content, songloop, songlen)
			source = Converter().convert(module)
			print("Writing %s source file" % sourcename)
			source.write(sourcename)
		else:
			print("No blocks in file")
	else:
		print("REverse compiler for ST pre 07 music modules.")
		print("(c) 2011 by Mono / Tristesse")
		print("\nUsage: %s compilename songloop songlen [sourcename]" % progname)
		print("\nGenerates .ST source file from compiled file.")

